﻿using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text;
using System.Xml;
using HAPI;
using HuginApiAddonsCS.Extensions;
using System.Threading;

namespace HuginApiAddonsCS.Policy
{
	/// <summary>
	/// A node of a policy tree. Each node carries an action label (see <see cref="GetNodeName"/>)
	/// and at most one child per observation. Trees can be built from a Hugin decision network,
	/// from action/observation histories, by cloning, or by random fill-in.
	/// </summary>
	public class PolicyNode : ICloneable
	{
		private string _nodeName = string.Empty;
		private List<string> _observations = new List<string>();
		private PolicyNode _parentNode = null;
		// Only set by the decision-network constructors; null for history/clone/random-fill nodes.
		private DiscreteNode _decisionNode = null;
		private DiscreteNode _observationNode = null;
		private Dictionary<string, PolicyNode> _childNodes = new Dictionary<string, PolicyNode>();
		// Observation whose child is currently being built by the network constructor. A child is only
		// added to _childNodes after its constructor returns, so GetObservationForChild falls back to this.
		private string _currentObs = string.Empty;
		private string _modelName = string.Empty;
		// Starts at 1 and is incremented each time MergeWithHistory passes through this node.
		private int _nodeCount = 1;

		/// <summary>
		/// Constructor for the root node: builds the whole policy tree from a decision network.
		/// </summary>
		/// <param name="decisionNode">decision node to determine best decision from</param>
		/// <param name="observationNode">observation node to add observations; null makes the root a leaf</param>
		public PolicyNode(DiscreteNode decisionNode, DiscreteNode observationNode)
			: this(decisionNode, observationNode, null)
		{
		}

		/// <summary>
		/// Builds a policy node, and recursively its whole subtree, from a decision network.
		/// Findings are retracted, the action/observation path from the root down to this node is
		/// selected as evidence, the domain is propagated, and the node is labelled with the state
		/// returned by <c>GetStateWithHighestReward</c>. One child is then built per observation
		/// state, using the next time-step decision and observation nodes.
		/// </summary>
		/// <param name="decisionNode">decision node to determine best decision from</param>
		/// <param name="observationNode">observation node to obtain observations from; null makes this node a leaf</param>
		/// <param name="parent">parent node, or null for the root</param>
		public PolicyNode(DiscreteNode decisionNode, DiscreteNode observationNode, PolicyNode parent)
		{
			_decisionNode = decisionNode;
			_observationNode = observationNode;
			_parentNode = parent;

			if (_observationNode != null)
			{
				for (uint i = 0; i < _observationNode.GetNumberOfStates(); i++)
				{
					_observations.Add(_observationNode.GetStateLabel(i));
				}
			}

			_decisionNode.GetHomeDomain().RetractFindings();
			if (_parentNode != null)
			{
				// Select every ancestor's action and the observation leading towards this node.
				_parentNode.SetPolicyHistory(this);
			}
			_decisionNode.GetHomeDomain().Propagate(Domain.Equilibrium.H_EQUILIBRIUM_SUM, Domain.EvidenceMode.H_EVIDENCE_MODE_NORMAL);
			_nodeName = _decisionNode.GetStateWithHighestReward();

			// Loop through the observations and add a node for each
			foreach (string observation in _observations)
			{
				_currentObs = observation;
				_childNodes.Add(observation, new PolicyNode((DiscreteNode)_decisionNode.GetNextTimeStepNode(), (DiscreteNode)_observationNode.GetNextTimeStepNode(), this));
			}
		}

		/// <summary>
		/// Constructs a root policy branch from an action/observation history.
		/// </summary>
		/// <param name="observations">all possible observations</param>
		/// <param name="history">alternating action, observation, action, ... list; consumed (mutated) by this call</param>
		/// <param name="modelName">model name the tree belongs to</param>
		public PolicyNode(List<string> observations, List<string> history, string modelName)
			: this(observations, history, modelName, null)
		{
		}

		/// <summary>
		/// Constructs a policy branch from an action/observation history. The first entry becomes
		/// this node's action; if an observation follows, a single child is created for it from
		/// the rest of the history.
		/// </summary>
		/// <param name="observations">all possible observations</param>
		/// <param name="history">alternating action, observation, action, ... list; consumed (mutated) by this call</param>
		/// <param name="modelName">model name the tree belongs to</param>
		/// <param name="parent">parent node, or null for the root</param>
		/// <exception cref="InvalidOperationException">history is empty</exception>
		public PolicyNode(List<string> observations, List<string> history, string modelName, PolicyNode parent)
		{
			_modelName = modelName;
			_parentNode = parent;
			// Record the observation set on every node, leaves included, so later MergeWithHistory,
			// RandomFill and count comparisons see all possible observations.
			_observations = observations ?? new List<string>();

			_nodeName = history.First();
			history.RemoveAt(0);

			if (history.Count > 0)
			{
				string obs = history.First();
				history.RemoveAt(0);
				_childNodes.Add(obs, new PolicyNode(observations, history, modelName, this));
			}
		}

		/// <summary>
		/// Constructor used for cloning; re-parents every supplied child to the new node.
		/// </summary>
		/// <param name="action">node name</param>
		/// <param name="observations">list of observations</param>
		/// <param name="modelName">model name</param>
		/// <param name="childNodes">child nodes, stored by reference</param>
		public PolicyNode(string action, List<string> observations, string modelName, Dictionary<string, PolicyNode> childNodes)
		{
			_nodeName = action;
			_childNodes = childNodes;
			_observations = observations;
			_modelName = modelName;

			foreach (PolicyNode child in childNodes.Values)
			{
				child.SetParent(this);
			}
		}

		/// <summary>
		/// Constructor used for random fill in; creates a childless node.
		/// </summary>
		/// <param name="action">action to become the node name/label</param>
		/// <param name="observations">list of potential observations</param>
		/// <param name="modelName">model name tree belongs to</param>
		/// <param name="parent">parent node</param>
		public PolicyNode(string action, List<string> observations, string modelName, PolicyNode parent)
		{
			_nodeName = action;
			_observations = observations;
			_modelName = modelName;
			_parentNode = parent;
		}

		/// <summary>
		/// Recursively selects, from the root down to this node, each node's action on its decision
		/// node and the observation state that leads towards <paramref name="childNode"/>.
		/// Only valid for trees built from a decision network.
		/// </summary>
		/// <param name="childNode">child node calling</param>
		public void SetPolicyHistory(PolicyNode childNode)
		{
			if (_parentNode != null)
			{
				_parentNode.SetPolicyHistory(this);
			}

			_decisionNode.SelectState((uint)_decisionNode.GetStateIndex(_nodeName));
			_observationNode.SelectState((uint)_observationNode.GetStateIndex(GetObservationForChild(childNode)));
		}

		/// <summary>
		/// Get the observation that resulted in a given child node. If the child is not (yet)
		/// stored in this node, the observation currently being expanded is returned instead.
		/// </summary>
		/// <param name="childNode">Child node</param>
		/// <returns>The observation name that resulted in the child node</returns>
		public string GetObservationForChild(PolicyNode childNode)
		{
			string observation = string.Empty;

			foreach (KeyValuePair<string, PolicyNode> pair in _childNodes)
			{
				if (pair.Value == childNode)
				{
					observation = pair.Key;
				}
			}

			// A child still inside its constructor is not in _childNodes yet.
			if (observation == string.Empty)
			{
				observation = _currentObs;
			}

			return observation;
		}

		/// <summary>
		/// Writes this node and its subtree as nested Node/Child XML elements.
		/// </summary>
		/// <param name="writer">writer to write the node's information to</param>
		public void ToXmlWriter(XmlWriter writer)
		{
			writer.WriteStartElement("Node");
			writer.WriteElementString("Action", _nodeName);

			foreach (KeyValuePair<string, PolicyNode> pair in _childNodes)
			{
				writer.WriteStartElement("Child");
				writer.WriteElementString("Observation", pair.Key);

				pair.Value.ToXmlWriter(writer);

				writer.WriteEndElement();
			}

			writer.WriteEndElement();
		}

		/// <summary>
		/// Converts this node and its subtree to a TikZ tree fragment.
		/// </summary>
		/// <returns>TikZ source; the root is emitted as the named node (P)</returns>
		public string ToTikzString()
		{
			StringBuilder sb = new StringBuilder();

			sb.Append(_parentNode == null ? @"\node[action](P){" : @"node[action]{");
			sb.Append(_nodeName);
			sb.Append(@"}" + Environment.NewLine);

			foreach (KeyValuePair<string, PolicyNode> pair in _childNodes)
			{
				sb.Append(@"child{" + Environment.NewLine);
				sb.Append(pair.Value.ToTikzString());
				sb.Append(@"edge from parent[->]" + Environment.NewLine);
				sb.Append(@"node [above]{");
				sb.Append(pair.Key);
				sb.Append("}" + Environment.NewLine);
				sb.Append(@"}" + Environment.NewLine);
			}

			return sb.ToString();
		}

		/// <summary>
		/// Used to output a string in the correct format for importing into i-DIDs later on
		/// </summary>
		/// <param name="tabs">indentation prefix for this level</param>
		/// <param name="currentTree">index of the tree being written; index 0 also writes the header</param>
		/// <returns>tree text (lines separated by "\n")</returns>
		public string ToTreeString(string tabs, int currentTree)
		{
			StringBuilder sb = new StringBuilder();

			if (_parentNode == null)
			{
				if (currentTree == 0)
				{
					sb.Append(String.Format("Horizon: {0}\n", GetTreeLength()));
					sb.Append(String.Format("Observations: {0}\n", _observations.Count));
				}
				sb.Append(String.Format("Vector : {0} : -> act {1}", currentTree, _nodeName));
			}

			foreach (KeyValuePair<string, PolicyNode> childNode in _childNodes)
			{
				sb.Append(String.Format("\n{0} * obs {1} -> act {2}", tabs, childNode.Key, childNode.Value.GetNodeName()));
				sb.Append(childNode.Value.ToTreeString(tabs + " ", currentTree));
			}

			return sb.ToString();
		}

		/// <summary>
		/// Gets the length (number of levels) of the tree from this node onwards
		/// </summary>
		/// <returns>the tree length from this node onwards; 1 for a leaf</returns>
		public int GetTreeLength()
		{
			int max = 0;

			// Evaluate each subtree only once (the depth is otherwise recomputed per comparison).
			foreach (PolicyNode child in _childNodes.Values)
			{
				max = Math.Max(max, child.GetTreeLength());
			}

			return max + 1;
		}

		/// <summary>
		/// Gets action-observation pairs from the root down to this node as a '-' separated string,
		/// for building a search string
		/// </summary>
		/// <param name="childNode">child node to determine the observation</param>
		/// <returns>output pair string</returns>
		public string ToHistoryPair(PolicyNode childNode)
		{
			StringBuilder sb = new StringBuilder();

			if (_parentNode != null)
			{
				sb.Append(_parentNode.ToHistoryPair(this) + "-");
			}

			sb.Append(_nodeName + "-");
			sb.Append(GetObservationForChild(childNode));

			return sb.ToString();
		}

		/// <summary>
		/// Recursively appends one "searchString,action," line per node to
		/// &lt;csvDirectory&gt;/&lt;decision node name&gt;.csv, writing a header when the file is new.
		/// Only valid for trees built from a decision network (the decision node is null otherwise).
		/// </summary>
		/// <param name="csvDirectory">directory to write csv files to</param>
		public void WriteToCsvFile(string csvDirectory)
		{
			StringBuilder sb = new StringBuilder();

			if (_parentNode != null)
			{
				sb.Append(_parentNode.ToHistoryPair(this));
			}

			sb.Append("," + _nodeName + ",");

			string file = Path.Combine(csvDirectory, _decisionNode.GetName() + ".csv");
			bool fileExists = File.Exists(file);

			using (StreamWriter outFileWriter = new StreamWriter(file, true))
			{
				if (!fileExists)
				{
					outFileWriter.WriteLine(@"SearchString,Action,");
				}
				outFileWriter.WriteLine(sb.ToString());
			}

			foreach (PolicyNode child in _childNodes.Values)
			{
				child.WriteToCsvFile(csvDirectory);
			}
		}

		/// <summary>
		/// Gets a list of nodes for a given level from self
		/// </summary>
		/// <param name="level">levels to go down; 1 returns this node</param>
		/// <returns>list of nodes</returns>
		public List<PolicyNode> GetNodesForLevel(int level)
		{
			List<PolicyNode> levelNodes = new List<PolicyNode>();

			if (level == 1)
			{
				levelNodes.Add(this);
				return levelNodes;
			}

			int nextLevel = level - 1;
			foreach (PolicyNode child in _childNodes.Values)
			{
				levelNodes.AddRange(child.GetNodesForLevel(nextLevel));
			}

			return levelNodes;
		}

		/// <summary>
		/// Gets the history of the given node as a csv string of action observation pairs
		/// (each prefixed with ',') by looping up through the tree
		/// </summary>
		/// <param name="childNode">child whose incoming observation ends the history; null for the node itself</param>
		/// <returns>history string; empty for the root when childNode is null</returns>
		public string GetNodeHistory(PolicyNode childNode)
		{
			string ownPair = childNode == null
				? string.Empty
				: "," + _nodeName + "," + GetObservationForChild(childNode);

			if (_parentNode == null)
			{
				return ownPair;
			}

			return _parentNode.GetNodeHistory(this) + ownPair;
		}

		/// <summary>
		/// Recursively loops down the tree to find the action for a given history.
		/// At each stage the first action and observation are taken from the list; if the action is
		/// not this node's action, or no child exists for the observation, an empty string is returned.
		/// If the history is empty the parent consumed all of it, so this node's action is returned.
		/// </summary>
		/// <param name="history">alternating action, observation, ... list; consumed (mutated) by this call</param>
		/// <returns>matching action, or an empty string when the history does not fit the tree</returns>
		public string GetActionForHistory(List<string> history)
		{
			if (history.Count == 0)
			{
				return _nodeName;
			}

			if (history[0] != _nodeName)
			{
				return string.Empty;
			}

			history.RemoveAt(0);
			if (history.Count == 0)
			{
				// History ends on an action without the observation needed to pick a child.
				return string.Empty;
			}

			string observation = history[0];
			history.RemoveAt(0);

			PolicyNode child;
			if (!_childNodes.TryGetValue(observation, out child))
			{
				return string.Empty;
			}

			return child.GetActionForHistory(history);
		}

		/// <summary>
		/// Determines whether a history of action observation pairs is compatible
		/// with the current node; unexplored observations count as compatible
		/// </summary>
		/// <param name="history">input history to compare with; consumed (mutated) by this call</param>
		/// <param name="modelName">model name</param>
		/// <returns>true if compatible</returns>
		public bool IsHistoryCompatible(List<string> history, string modelName)
		{
			if (history.First() != _nodeName || modelName != _modelName)
			{
				return false;
			}

			// Final node
			if (history.Count == 1)
			{
				return true;
			}

			// Not the final node: need to check if the observation exists
			history.RemoveAt(0);
			PolicyNode child;
			if (!_childNodes.TryGetValue(history.First(), out child))
			{
				// No child node set for observation so compatible
				return true;
			}

			history.RemoveAt(0);
			return child.IsHistoryCompatible(history, modelName);
		}

		/// <summary>
		/// Checks whether the given history leads through this subtree, comparing every action
		/// except the final one
		/// </summary>
		/// <param name="history">comparison history; consumed (mutated) by this call</param>
		/// <returns>true if same history, else false</returns>
		public bool HasSameHistory(List<string> history)
		{
			// If this is the final entry, assume the correct history got you here
			if (history.Count <= 1)
			{
				return true;
			}

			if (history.First() != _nodeName)
			{
				return false;
			}

			history.RemoveAt(0);
			PolicyNode child;
			if (!_childNodes.TryGetValue(history.First(), out child))
			{
				// No child node set for observation so not same history
				return false;
			}

			history.RemoveAt(0);
			return child.HasSameHistory(history);
		}

		/// <summary>
		/// Merges a branch with a history, creating new parts of branches where required and
		/// incrementing the visit count of every node the history passes through
		/// </summary>
		/// <param name="history">input history; consumed (mutated) by this call</param>
		/// <param name="modelName">model name</param>
		/// <returns>true if merge was successful</returns>
		public bool MergeWithHistory(List<string> history, string modelName)
		{
			if (history.First() != _nodeName || modelName != _modelName)
			{
				return false;
			}

			_nodeCount++;

			// Final node
			if (history.Count == 1)
			{
				return true;
			}

			// Not the final node: follow or create the branch for the observation
			history.RemoveAt(0);
			string obs = history.First();
			history.RemoveAt(0);

			PolicyNode child;
			if (_childNodes.TryGetValue(obs, out child))
			{
				return child.MergeWithHistory(history, modelName);
			}

			_childNodes.Add(obs, new PolicyNode(_observations, history, modelName, this));
			return true;
		}

		/// <summary>
		/// Fills in missing branches with random actions
		/// </summary>
		/// <param name="rand">pre seeded random</param>
		/// <param name="treeLength">max tree length counting this node; values of 1 or less stop the fill</param>
		/// <param name="actions">list of potential actions</param>
		public void RandomFill(Random rand, int treeLength, List<string> actions)
		{
			// "<= 1" rather than "== 1" so a non-positive length cannot recurse forever.
			if (treeLength <= 1)
			{
				return;
			}

			foreach (string observation in _observations)
			{
				PolicyNode child;
				if (!_childNodes.TryGetValue(observation, out child))
				{
					// Choose a random action
					string action = actions[rand.Next(actions.Count)];
					child = new PolicyNode(action, _observations, _modelName, this);
					_childNodes.Add(observation, child);
				}

				child.RandomFill(rand, treeLength - 1, actions);
			}
		}

		/// <summary>
		/// Random fill for when tree length is unknown; uses the current length of this subtree
		/// </summary>
		/// <param name="rand">pre seeded random</param>
		/// <param name="actions">list of potential actions</param>
		public void RandomFill(Random rand, List<string> actions)
		{
			int treeLength = GetTreeLength();
			RandomFill(rand, treeLength, actions);
		}

		/// <summary>
		/// Random fill when only potential actions are known
		/// </summary>
		/// <param name="actions">list of potential actions</param>
		public void RandomFill(List<string> actions)
		{
			Random rand = new Random();
			RandomFill(rand, actions);
		}

		/// <summary>
		/// Gets the occurrences of each observation, i.e. the visit count of the child for each
		/// observation, or 0 when there is no child
		/// </summary>
		/// <returns>dictionary linking observations to their occurrences</returns>
		public Dictionary<string, int> GetObservationCounts()
		{
			Dictionary<string, int> observationCounts = new Dictionary<string, int>();

			foreach (string observation in _observations)
			{
				PolicyNode child;
				int count = _childNodes.TryGetValue(observation, out child) ? child.GetNodeCount() : 0;
				observationCounts.Add(observation, count);
			}

			return observationCounts;
		}

		/// <summary>
		/// Compares two nodes recursively to determine if they are hoeffding compatible
		/// </summary>
		/// <param name="node">node to compare self with</param>
		/// <param name="epsilon">epsilon value as part of hoeffding test</param>
		/// <returns>true if compatible, false if not</returns>
		public bool HoeffdingTest(PolicyNode node, double epsilon)
		{
			if (_nodeName != node.GetNodeName())
			{
				return false;
			}

			if (_childNodes.Count == 0)
			{
				return true;
			}

			Dictionary<string, int> selfObservationCounts = GetObservationCounts();
			Dictionary<string, int> nodeObservationCounts = node.GetObservationCounts();

			// The bound only depends on the two visit counts and epsilon, so compute it once.
			double right = (Math.Sqrt(1d / GetNodeCount()) + Math.Sqrt(1d / node.GetNodeCount()))
				* Math.Sqrt(0.5 * Math.Log(2d / epsilon));

			foreach (string observation in _observations)
			{
				if (GetProportionDifference(selfObservationCounts, node, nodeObservationCounts, observation) > right)
				{
					return false;
				}
			}

			// Loop through each observation and perform test on each child node
			foreach (string observation in _observations)
			{
				PolicyNode selfChild = GetNodeForObservation(observation);
				PolicyNode nodeChild = node.GetNodeForObservation(observation);

				if (selfChild != null && nodeChild != null && !selfChild.HoeffdingTest(nodeChild, epsilon))
				{
					return false;
				}
			}

			return true;
		}

		/// <summary>
		/// An alternative to the hoeffding test, just compares the
		/// proportions of each child node against a fixed maximum difference
		/// </summary>
		/// <param name="node">node to compare self with</param>
		/// <param name="difference">largest allowed absolute difference between proportions</param>
		/// <returns>true if similar, false if not</returns>
		public bool SimilarNodeTest(PolicyNode node, double difference)
		{
			if (_nodeName != node.GetNodeName())
			{
				return false;
			}

			if (_childNodes.Count == 0)
			{
				return true;
			}

			Dictionary<string, int> selfObservationCounts = GetObservationCounts();
			Dictionary<string, int> nodeObservationCounts = node.GetObservationCounts();

			foreach (string observation in _observations)
			{
				if (GetProportionDifference(selfObservationCounts, node, nodeObservationCounts, observation) > difference)
				{
					return false;
				}
			}

			// Loop through each observation and perform test on each child node
			foreach (string observation in _observations)
			{
				PolicyNode selfChild = GetNodeForObservation(observation);
				PolicyNode nodeChild = node.GetNodeForObservation(observation);

				if (selfChild != null && nodeChild != null && !selfChild.SimilarNodeTest(nodeChild, difference))
				{
					return false;
				}
			}

			return true;
		}

		/// <summary>
		/// Absolute difference between (observation count / node count) for this node and for
		/// <paramref name="node"/>. An observation missing from a count dictionary counts as 0.
		/// </summary>
		private double GetProportionDifference(Dictionary<string, int> selfCounts, PolicyNode node, Dictionary<string, int> nodeCounts, string observation)
		{
			int selfCount;
			selfCounts.TryGetValue(observation, out selfCount);
			int nodeCount;
			nodeCounts.TryGetValue(observation, out nodeCount);

			return Math.Abs(((double)selfCount / (double)GetNodeCount()) - ((double)nodeCount / (double)node.GetNodeCount()));
		}

		/// <summary>
		/// Returns the child node for a given observation, null if no child exists for observation
		/// </summary>
		/// <param name="observation">observation</param>
		/// <returns>child node or null</returns>
		public PolicyNode GetNodeForObservation(string observation)
		{
			PolicyNode child;
			return _childNodes.TryGetValue(observation, out child) ? child : null;
		}

		/// <summary>
		/// Returns the name (action label) of the node
		/// </summary>
		/// <returns>Name of the node</returns>
		public string GetNodeName()
		{
			return _nodeName;
		}

		/// <summary>
		/// Returns how often the node occurred (starts at 1, incremented by MergeWithHistory)
		/// </summary>
		/// <returns>visit count</returns>
		public int GetNodeCount()
		{
			return _nodeCount;
		}

		/// <summary>
		/// Checks if node has any missing child nodes, if missing any
		/// will return false
		/// </summary>
		/// <returns>true when every observation has a child</returns>
		public bool IsComplete()
		{
			return _observations.All(observation => _childNodes.ContainsKey(observation));
		}

		/// <summary>
		/// Merges node with another node by filling in missing child nodes with clones of
		/// those from the other node; observations missing from both nodes stay missing
		/// </summary>
		/// <param name="node">node to take missing branches from</param>
		public void MergeWithNode(PolicyNode node)
		{
			foreach (string observation in _observations)
			{
				if (_childNodes.ContainsKey(observation))
				{
					continue;
				}

				PolicyNode donor = node.GetNodeForObservation(observation);
				if (donor == null)
				{
					// Neither node has this branch; nothing to copy.
					continue;
				}

				PolicyNode copy = (PolicyNode)donor.Clone();
				copy.SetParent(this);
				_childNodes.Add(observation, copy);
			}
		}

		/// <summary>
		/// Used to manually set the child nodes of a policy node (parents are not updated)
		/// </summary>
		/// <param name="childNodes">input child nodes</param>
		public void SetChildNodes(Dictionary<string, PolicyNode> childNodes)
		{
			_childNodes = childNodes;
		}

		/// <summary>
		/// Used to reset a nodes parent
		/// </summary>
		/// <param name="node">parent node to set</param>
		public void SetParent(PolicyNode node)
		{
			_parentNode = node;
		}

		#region ICloneable Members

		/// <summary>
		/// Deep-copies the subtree rooted at this node (the observation list is shared).
		/// The clone has no parent.
		/// </summary>
		/// <returns>the cloned <see cref="PolicyNode"/></returns>
		public object Clone()
		{
			Dictionary<string, PolicyNode> clonedChildNodes = new Dictionary<string, PolicyNode>();

			foreach (KeyValuePair<string, PolicyNode> keyValuePair in _childNodes)
			{
				clonedChildNodes.Add(keyValuePair.Key, (PolicyNode)keyValuePair.Value.Clone());
			}

			// The cloning constructor stores clonedChildNodes and re-parents each clone.
			// NOTE(review): _nodeCount and the Hugin nodes are not copied; the clone starts with a count of 1.
			return new PolicyNode(_nodeName, _observations, _modelName, clonedChildNodes);
		}

		#endregion
	}
}
